{
 "metadata": {
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.12"
  },
  "orig_nbformat": 2,
  "kernelspec": {
   "name": "python3612jvsc74a57bd0eb19a231d6cdeddfa7d782706fd68cae3c49b44aaf8515299fdcf94304a50177",
   "display_name": "Python 3.6.12 64-bit ('tf2': conda)"
  },
  "metadata": {
   "interpreter": {
    "hash": "eb19a231d6cdeddfa7d782706fd68cae3c49b44aaf8515299fdcf94304a50177"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2,
 "cells": [
  {
   "cell_type": "markdown",
   "source": [
    "# Inference Demo for ResNet-50 FP32/INT8 Models (With ONNX Runtime)"
   ],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Overview\r\n",
    "\r\n",
    "This notebook shows how to run inference with trained ResNet-50 models in ONNX Runtime. Install the prerequisite packages listed below if they are not already installed."
   ],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Prerequisites\r\n",
    "\r\n",
    "* Protobuf compiler - `sudo apt-get install protobuf-compiler libprotoc-dev` (required by ONNX. This command applies to Debian/Ubuntu-based Linux systems; for other platforms, see the installation guidelines in the [ONNX documentation](https://github.com/onnx/onnx#installation))\r\n",
    "* ONNX - `pip install onnx`\r\n",
    "* ONNXRuntime - `pip install onnxruntime`\r\n",
    "* matplotlib - `pip install matplotlib`\r\n",
    "* PIL - `pip install Pillow`\r\n",
    "* numpy - `pip install numpy`\r\n",
    "* cv2 - `pip install opencv-python`\r\n",
    "\r\n",
    "To run inference from a Python script instead of the notebook: \r\n",
    "* Generate the script: in the Jupyter Notebook interface, go to File -> Download as -> Python (.py)\r\n",
    "* Run the script: `python onnxrt_inference.py`"
   ],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "### Import dependencies\n",
    "\n",
    "Verify that all dependencies are installed by running the cell below. Continue if no errors occur; warnings can be ignored."
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "source": [
    "import onnx\r\n",
    "import numpy as np\r\n",
    "import onnxruntime as ort\r\n",
    "from PIL import Image\r\n",
    "import cv2\r\n",
    "import matplotlib.pyplot as plt"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "### Prepare image and label file\r\n",
    "\r\n",
    "Download image: \r\n",
    "`wget 'https://s3.amazonaws.com/model-server/inputs/kitten.jpg'`\r\n",
    "\r\n",
    "Download label file:\r\n",
    "`wget 'https://s3.amazonaws.com/onnx-model-zoo/synset.txt'`"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "source": [
    "with open('synset.txt', 'r') as f:\r\n",
    "    labels = [l.rstrip() for l in f]"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "### Import ONNX model\n",
    "\n",
    "Load the ONNX model from disk and create an ONNX Runtime inference session."
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "model_path = 'resnet50-v1-12.onnx'\r\n",
    "model = onnx.load(model_path)\r\n",
    "# Starting with ORT 1.10, the providers parameter must be set explicitly when instantiating InferenceSession\n",
    "# if you want to use an execution provider other than the default CPU provider. Previously, providers were\n",
    "# set/registered by default based on the build flags.\n",
    "# For example, if an NVIDIA GPU is available and the ORT Python package is built with CUDA, call the API as follows:\n",
    "# ort.InferenceSession(model_path, providers=['CUDAExecutionProvider'])\n",
    "session = ort.InferenceSession(model.SerializeToString(), providers=['CPUExecutionProvider'])"
   ],
   "outputs": [],
   "metadata": {}
  },
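  {
   "cell_type": "markdown",
   "source": [
    "Before running inference, it can help to check what the session expects and returns. The cell below is a minimal sketch that assumes the `session` object created above; the names, shapes and providers it prints depend on the model file and on the installed ONNX Runtime build."
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "# Illustrative check only: list the model's inputs and outputs as seen by ONNX Runtime\n",
    "for inp in session.get_inputs():\n",
    "    print('input :', inp.name, inp.shape, inp.type)\n",
    "for out in session.get_outputs():\n",
    "    print('output:', out.name, out.shape, out.type)\n",
    "# Execution providers actually in use for this session\n",
    "print('providers:', session.get_providers())"
   ],
   "outputs": [],
   "metadata": {}
  },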
  {
   "cell_type": "markdown",
   "source": [
    "### Read image\n",
    "\n",
    "`get_image(path, show=False)`: Reads the image at `path` as an RGB array and, if `show` is `True`, displays it."
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "def get_image(path, show=False):\r\n",
    "    with Image.open(path) as img:\r\n",
    "        img = np.array(img.convert('RGB'))\r\n",
    "    if show:\r\n",
    "        plt.imshow(img)\r\n",
    "        plt.axis('off')\r\n",
    "    return img"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "### Preprocess image\n",
    "\n",
    "`preprocess(img)`: Preprocesses the inference image: scale pixel values to [0, 1], resize to 256x256, take a 224x224 center crop, normalize each channel with the ImageNet mean and standard deviation, transpose to channel-first (CHW) layout, cast to float32, and add a batch dimension to obtain an NCHW tensor."
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "def preprocess(img):\r\n",
    "    img = img / 255.\r\n",
    "    img = cv2.resize(img, (256, 256))\r\n",
    "    h, w = img.shape[0], img.shape[1]\r\n",
    "    y0 = (h - 224) // 2\r\n",
    "    x0 = (w - 224) // 2\r\n",
    "    img = img[y0 : y0+224, x0 : x0+224, :]\r\n",
    "    img = (img - [0.485, 0.456, 0.406]) / [0.229, 0.224, 0.225]\r\n",
    "    img = np.transpose(img, axes=[2, 0, 1])\r\n",
    "    img = img.astype(np.float32)\r\n",
    "    img = np.expand_dims(img, axis=0)\r\n",
    "    return img"
   ],
   "outputs": [],
   "metadata": {}
  },
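  {
   "cell_type": "markdown",
   "source": [
    "The preprocessing steps above can be sanity-checked without a real image. The cell below is a small sketch, not part of the original workflow: it assumes the `preprocess` function defined above and uses a random array (the hypothetical `dummy`) in place of a decoded RGB image to confirm the output shape and dtype."
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "# Hypothetical stand-in for a decoded image: height 300, width 400, 3 channels, uint8 values\n",
    "dummy = np.random.randint(0, 256, size=(300, 400, 3), dtype=np.uint8)\n",
    "batch = preprocess(dummy)\n",
    "print(batch.shape, batch.dtype)\n",
    "# One image, 3 channels, 224x224 crop, in float32\n",
    "assert batch.shape == (1, 3, 224, 224)\n",
    "assert batch.dtype == np.float32"
   ],
   "outputs": [],
   "metadata": {}
  },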
  {
   "cell_type": "markdown",
   "source": [
    "### Predict\n",
    "\n",
    "`predict(path)`: Takes the `path` of the input image, displays the image, and prints the top-1 prediction.\n"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "def predict(path):\r\n",
    "    img = get_image(path, show=True)\r\n",
    "    img = preprocess(img)\r\n",
    "    ort_inputs = {session.get_inputs()[0].name: img}\r\n",
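    "    scores = np.squeeze(session.run(None, ort_inputs)[0])\r\n",
    "    # Assumes the model outputs raw class scores (logits); softmax converts them to probabilities\r\n",
    "    exp = np.exp(scores - np.max(scores))\r\n",
    "    preds = exp / np.sum(exp)\r\n",
    "    a = np.argsort(preds)[::-1]\r\n",
    "    print('class=%s ; probability=%f' % (labels[a[0]], preds[a[0]]))"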
   ],
   "outputs": [],
   "metadata": {}
  },
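  {
   "cell_type": "markdown",
   "source": [
    "If more than one candidate class is of interest, the same steps extend to a top-k listing. The cell below is a sketch under stated assumptions: `predict_top_k` is a hypothetical helper, not part of the original notebook; it reuses `session`, `labels`, `get_image` and `preprocess` from the cells above, and it assumes the model returns raw class scores, so it applies a softmax before printing."
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "def predict_top_k(path, k=5):\n",
    "    # Hypothetical helper: prints the k most likely classes for the image at `path`\n",
    "    img = preprocess(get_image(path))\n",
    "    ort_inputs = {session.get_inputs()[0].name: img}\n",
    "    scores = np.squeeze(session.run(None, ort_inputs)[0])\n",
    "    # Assumption: the model outputs raw scores, so a numerically stable softmax is applied\n",
    "    exp = np.exp(scores - np.max(scores))\n",
    "    probs = exp / np.sum(exp)\n",
    "    # Indices of the k largest probabilities, highest first\n",
    "    for idx in np.argsort(probs)[::-1][:k]:\n",
    "        print('class=%s ; probability=%f' % (labels[idx], probs[idx]))"
   ],
   "outputs": [],
   "metadata": {}
  },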
  {
   "cell_type": "markdown",
   "source": [
    "### Generate predictions\n",
    "\n",
    "The top-1 class and its probability for the image are displayed in the output of the cell below.\n"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "# Enter path to the inference image below\r\n",
    "img_path = 'kitten.jpg'\r\n",
    "predict(img_path)"
   ],
   "outputs": [],
   "metadata": {}
  }
 ]
}
